Update for near perfect correlation with pystoi
#9
Conversation
It is a great improvement, thank you very much!
A few things:
- Can you try it on a GPU to make sure that it runs?
- Can you sum up the changes you had to make?
- And address the minor comments I left?
Thanks again, great work!
Also, feel free to change the graphs in the README within the PR, it's worth it 👀
Will do
This is inspired by mpariente/pystoi#28.
Done
I guess you haven't committed the changes yet, but it all sounds good.
No, not yet, I am working on it. It turns out some fixes are required to run on CUDA. I will probably finish after the weekend.
Done.
Edit: pressed "Comment" before finishing my message.
Thanks again, is the benchmark on GPU or CPU?
That last benchmark was on GPU, on a V100.
What is the […]?
Which part of the code adds the most time?
The […]. The benchmarking script can be found here. Namely, the measured function is

```python
def to_time():
    loss = criterion(x, y).mean()
    loss.backward(retain_graph=True)
```

where […]. I can try and profile the code tomorrow and see what's taking the most time. I profiled on CPU and there it was the […].
Also, I ran this on a computing cluster and did not make sure both versions ran on the same host. I can try and fix that tomorrow too.
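For reference, a minimal sketch of what such a profiling run might look like with `torch.profiler`; the `NegSTOILoss` arguments, input shapes, and batch size are illustrative assumptions, not the actual benchmark setup:

```python
import torch
from torch.profiler import profile, ProfilerActivity
from torch_stoi import NegSTOILoss

# Illustrative setup: 2 signals of 1 s at 16 kHz (assumed, not the PR's exact benchmark)
criterion = NegSTOILoss(sample_rate=16000)
x = torch.randn(2, 16000, requires_grad=True)  # estimated signals
y = torch.randn(2, 16000)                      # clean references

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    loss = criterion(x, y).mean()
    loss.backward(retain_graph=True)

# Show which ops dominate the total CPU time
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```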
Okay, I made sure to run it again on the same host, and processing is indeed about 2× longer. Below are the full results.

master: […]

this PR: […]
```python
# Keep only the frames selected by each mask, then zero-pad back to full length
return torch.stack([
    pad(xi[mi], (0, 0, 0, len(xi) - mi.sum())) for xi, mi in zip(x, mask)
])
```
This is the list comprehension which takes the most time, and it makes sense that it does.
Thanks for running the profiler! I think a 2× increase in computation time is still worth the improvement, especially since it will be negligible compared to the NN's training time.
I merged this PR in […].
I also think it's worth the improvement. The time increase also happens inside […]. Could this also slow down the backward pass of the whole architecture during training?
Sure. I will also try and think of a way to avoid this list comprehension.
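One possible direction, sketched under the assumption that `x` is a `(batch, frames, freq)` tensor and `mask` a `(batch, frames)` boolean tensor as in the snippet above: a stable sort can move the kept frames to the front in their original order, replacing the per-item list comprehension with batched ops. An untested sketch, not the PR's code:

```python
import torch

def masked_pad_batched(x, mask):
    """Batched equivalent of the list comprehension above, without a Python loop.

    Assumes x: (batch, frames, freq) and mask: (batch, frames), boolean.
    """
    # Stable sort on the inverted mask: kept frames (False -> 0) come first,
    # in their original order; dropped frames are pushed to the tail.
    order = torch.sort((~mask).long(), dim=1, stable=True).indices
    gathered = torch.gather(x, 1, order.unsqueeze(-1).expand_as(x))
    # Zero out the tail so it matches the zero-padding of the original code.
    kept = mask.sum(dim=1, keepdim=True)                   # (batch, 1)
    frame_idx = torch.arange(x.shape[1], device=x.device)  # (frames,)
    return gathered * (frame_idx.unsqueeze(0) < kept).unsqueeze(-1)
```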
I don't think so. Backward through the loss is the first operation of the overall backward; the backward time of the whole architecture is independent of the loss, since it only needs the gradients as input.
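A tiny self-contained illustration of that point (the model and loss here are stand-ins, purely for exposition): the architecture's backward consumes only the gradient tensor at its output, however that gradient was produced.

```python
import torch

model = torch.nn.Linear(8, 8)          # stand-in architecture
criterion = torch.nn.MSELoss()         # stand-in loss
x, target = torch.randn(4, 8), torch.randn(4, 8)

out = model(x)
loss = criterion(out, target)
# "Backward through the loss": gradient of the loss w.r.t. the model output.
grad_out = torch.autograd.grad(loss, out, retain_graph=True)[0]
# Backward through the architecture only needs that gradient tensor;
# its cost is the same whichever loss produced grad_out.
out.backward(grad_out)
```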
The latest commit now reflects mpariente/pystoi#33. I made new correlation plots and checked that the correlation is still strong. I can commit them if you want. I'm not sure if we should mention that the correlation is against the latest commit in pystoi, since no new pystoi version was released after mpariente/pystoi#33. The correlation against the current pystoi version, i.e. pre-mpariente/pystoi#33, would not be as strong.
Sure, add a commit tag somewhere in a comment.
Do you think the difference is that big?
Could you make a correlation plot between the former and new versions of pySTOI, please?
Thanks! It's roughly the same, right?
Yes, I can't see the difference unless I superimpose the plots and keep Alt-Tabbing.
So please add a comment in the code and the README, and we'll merge this.
Done. I also updated the docstring to sound less dissuasive about using this vs. pystoi, while still emphasizing the difference. Maybe you can have a look.
This PR significantly improves correlations with `pystoi` at little performance cost. When `use_vad=True`, the correlation is near-perfect. The only reason it's not truly perfect is the different resampling technique. The changes are extensive and I am ready to discuss them with you.

This also fixes two warnings (a sketch of the updated calls follows the list):

- `torch.stft` in `NegSTOILoss.stft`, because `return_complex=False` is deprecated
- `torchaudio.transforms.Resample` in `NegSTOILoss.__init__`, because `resampling_method="sinc_interpolation"` is deprecated and should now be `"sinc_interp_hann"`
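For illustration, roughly what the two updated calls look like; a hedged sketch with made-up signal lengths and FFT size, not the PR's exact code:

```python
import torch
import torchaudio

wav = torch.randn(16000)  # made-up 1 s signal at 16 kHz

# torch.stft: return_complex=False is deprecated, so request a complex
# spectrogram and view it as real if real/imag pairs are needed downstream.
window = torch.hann_window(512)
spec = torch.view_as_real(
    torch.stft(wav, n_fft=512, window=window, return_complex=True)
)

# torchaudio.transforms.Resample: "sinc_interpolation" was renamed.
# 10 kHz is the rate the STOI measure operates at.
resample = torchaudio.transforms.Resample(
    orig_freq=16000, new_freq=10000, resampling_method="sinc_interp_hann"
)
wav_10k = resample(wav)
```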
Correlations

Below are correlation plots obtained with this fix. These were obtained using mixtures I generated, which I can provide if needed. As you can see, when `use_vad=True` we obtain near-perfect correlation. I did not include these plots in this PR because I didn't know if the same mixtures should be used to recreate them.

For reference, below is `master` according to `README.md` vs. this PR w/ VAD. Quite an improvement IMO.

Processing times
Below I compare processing times between `master` and this PR. The evaluated function is a `NegSTOILoss.forward` pass followed by a `loss.backward` pass, where the `nn.Module` is the `TestNet` in `test_torchstoi.py`. When `use_vad=True`, processing is a bit slower, but worth the accuracy improvement IMO. With `use_vad=False`, it seems it's actually slightly faster. These were obtained on an HP EliteBook G6 with an Intel(R) Core(TM) i7-8665U CPU @ 1.90GHz. I had no free GPU at hand when testing, so I couldn't try on GPU.
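To make the benchmark concrete, here is a minimal, self-contained sketch of the kind of measurement described above; the stand-in network, input shapes, and repeat count are assumptions, not the actual `TestNet` or benchmark script:

```python
import timeit
import torch
from torch_stoi import NegSTOILoss

# Stand-in for the TestNet used in test_torchstoi.py (assumed shape conventions)
net = torch.nn.Conv1d(1, 1, kernel_size=3, padding=1)
criterion = NegSTOILoss(sample_rate=16000, use_vad=True)  # the option this PR adds

mix = torch.randn(2, 1, 16000)  # batch of 1 s mixtures at 16 kHz
clean = torch.randn(2, 16000)   # clean references
est = net(mix).squeeze(1)       # forward pass through the network

def to_time():
    loss = criterion(est, clean).mean()
    loss.backward(retain_graph=True)

# Rough wall-clock timing of forward + backward through the loss
print(timeit.timeit(to_time, number=20), "s for 20 runs")
```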